#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# Some parts of the code are borrowed from: https://github.com/ansleliu/LightNet
# with the following license:
#
# MIT License
#
# Copyright (c) 2018 Huijun Liu
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import os
import scipy.misc as misc
import sys
import cv2
__package__ = "pytorch_jacinto_ai.xvision.datasets.pixel2pixel"
from ... import xnn
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
from .a2d2 import A2D2BaseSegmentationLoader, A2D2BaseMotionLoader

def calc_median_frequency(classes, present_num):
    """
    Class balancing by median frequency balancing method.
    Reference: https://arxiv.org/pdf/1411.4734.pdf
       'a = median_freq / freq(c) where freq(c) is the number of pixels
        of class c divided by the total number of pixels in images where
        c is present, and median_freq is the median of these frequencies.'

    Here freq(c) is computed as ``classes[c] / present_num[c]`` (pixels of
    class c per image in which c occurs), which is proportional to the
    paper's definition when all images have the same size.

    Args:
        classes: per-class total pixel counts (array-like).
        present_num: per-class number of images in which the class occurs
            (array-like, same length as ``classes``).

    Returns:
        numpy float array of per-class weights. Classes that never occur
        (``present_num == 0`` or a zero pixel count) get weight 0 and are
        excluded from the median.
    """
    classes = np.asarray(classes, dtype=np.float64)
    present_num = np.asarray(present_num, dtype=np.float64)

    # Decide presence from the data itself rather than from a sentinel count
    # of 1.0, so genuine single-pixel classes still take part in the median.
    present = (present_num > 0) & (classes > 0)

    class_freq = np.zeros_like(classes)
    np.divide(classes, present_num, out=class_freq, where=present)

    weights = np.zeros_like(classes)
    if not present.any():
        # No class occurs at all: there is no median to balance against.
        return weights
    median_freq = np.median(class_freq[present])
    np.divide(median_freq, class_freq, out=weights, where=present)
    return weights

def calc_log_frequency(classes, value=1.02):
    """Class balancing by the ERFNet (logarithmic) method.

       prob = each_sum_pixel / each_sum_pixel.sum()
       a = 1 / log(value + prob)          (value defaults to 1.02)

    Args:
        classes: per-class pixel counts (array-like; lists are accepted).
        value: constant added inside the logarithm. Since the weight
            decreases as prob grows, the largest possible weight is
            1 / log(value), reached when prob == 0.

    Returns:
        numpy array of per-class weights.
    """
    # asarray lets plain lists work; float arrays keep their dtype.
    classes = np.asarray(classes)
    # Normalise by the total (not the max) so prob is a class probability.
    class_freq = classes / classes.sum()
    return 1 / np.log(value + class_freq)

def print_stats(classes=None, class_weight=None):
    """Print per-class pixel counts, their percentage share, and class weights.

    Args:
        classes: per-class pixel counts (iterable of numbers). Defaults to empty.
        class_weight: per-class weights (iterable of numbers). Defaults to empty.
    """
    # None sentinels instead of shared mutable [] defaults.
    classes = [] if classes is None else classes
    class_weight = [] if class_weight is None else class_weight
    separator = "-" * 32

    print("class_freq \n", separator)
    for idx, class_freq in enumerate(classes):
        print("{} : {:.0f}".format(idx, class_freq))
    print(separator)

    print("class_freq in % \n", separator)
    # The total is loop-invariant; compute it once.
    total = np.sum(classes)
    for idx, class_freq in enumerate(classes):
        print("{} : {:.2f}".format(idx, class_freq * 100.0 / total))
    print(separator)

    print(separator)
    print("class weights \n", separator)
    for idx, class_wt in enumerate(class_weight):
        print("{} : {:.08}".format(idx, class_wt))


def study_frame_level_stat(frame_level_stat=None, frame_names=None,
                           op_path='/data/ssd/datasets/a2d2_v2/info/frame_level_stat/plot/',
                           show_difficult_class=True, classes_under_study=None):
    """Plot per-frame pixel counts for selected classes and report outlier frames.

    For each class under study, its pixel count across all frames is plotted
    in its own subplot. Every frame whose count exceeds twice that class's
    average is printed. The figure is saved as ``frame_by_frame_stats.jpg``
    inside ``op_path`` and then closed.

    Args:
        frame_level_stat: 2-D array-like indexed as [class_idx][frm_idx]
            holding per-frame pixel counts.
        frame_names: sequence of frame names indexed by frm_idx. If missing
            or too short, the frame index is printed instead.
        op_path: output directory for the plot (created if missing).
        show_difficult_class: if True, plot a fixed list of difficult classes
            in one column; otherwise plot all classes in two columns.
        classes_under_study: optional explicit list of class indices to plot.
            It overrides the class selection implied by
            ``show_difficult_class``, which then only sets the column count.

    Raises:
        ValueError: if ``frame_level_stat`` is not given.
    """
    if frame_level_stat is None:
        raise ValueError("frame_level_stat is required")
    frame_level_stat = np.asarray(frame_level_stat)
    if frame_names is None:
        frame_names = []
    num_classes = frame_level_stat.shape[0]
    os.makedirs(op_path, exist_ok=True)

    # Show only the difficult classes, or all classes.
    if show_difficult_class:
        default_classes = [11, 13, 17, 18, 19, 22, 25]
        num_col_plots = 1
    else:
        default_classes = list(range(num_classes))
        num_col_plots = 2
    if classes_under_study is None:
        classes_under_study = default_classes
    # Skip class indices that do not exist in this statistics table.
    classes_under_study = [c for c in classes_under_study if 0 <= c < num_classes]

    # Round up so an odd class count still gets a subplot for the last class.
    # squeeze=False keeps axs 2-D even for a single row or column.
    num_rows = max(1, (len(classes_under_study) + num_col_plots - 1) // num_col_plots)
    fig, axs = plt.subplots(num_rows, num_col_plots, squeeze=False)
    fig.suptitle('Frame by frame stats')

    for plot_idx, class_idx in enumerate(classes_under_study):
        print("class_idx: ", class_idx)
        row = frame_level_stat[class_idx]
        ax = axs[plot_idx // num_col_plots, plot_idx % num_col_plots]
        ax.plot(row)
        if num_col_plots == 1:
            ax.set_title(class_idx, loc='right')
        else:
            ax.set_title(class_idx)

        # Report frames where this class occurs more than twice its average.
        avg_class = np.average(row)
        for frm_idx in np.flatnonzero(row > (avg_class * 2)):
            name = frame_names[frm_idx] if frm_idx < len(frame_names) else frm_idx
            print("frm, occurance", name, row[frm_idx])

    fig.savefig(os.path.join(op_path, 'frame_by_frame_stats.jpg'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def calc_weights(method="median", lbls_path="/data/ssd/datasets/a2d2_v2/gtFine/train/", num_classes=38):
    """Compute per-class label statistics and class-balancing weights.

    Every ``.png`` found recursively under ``lbls_path`` is read as grayscale
    and mapped through ``A2D2BaseSegmentationLoader().encode_segmap``. Three
    statistics are accumulated:
    - the per-class total pixel count;
    - the per-class number of images containing the class;
    - a per-frame pixel-count table.

    The per-frame table and the frame names are saved with ``np.save`` as
    ``train_frame_level_stat.npy`` / ``train_frame_names.npy`` in the current
    working directory, then plotted via ``study_frame_level_stat``. Finally
    the class weights are computed and printed.

    Args:
        method: "median" (median frequency balancing) or "log" (ERFNet style).
        lbls_path: root directory searched recursively for label images.
        num_classes: number of classes counted (label values 0..num_classes-1).

    Raises:
        ValueError: if ``method`` is not recognised (checked before scanning).
        IOError: if a label image cannot be read.
    """
    # Validate up front so a typo does not waste a full pass over the dataset.
    if method not in ("median", "log"):
        raise ValueError("Please assign method to 'median' or 'log'")

    labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='.png')
    # Other datasets can be studied by swapping the loader, e.g.
    # CityscapesBaseSegmentationLoader() or CityscapesBaseMotionLoader().
    dst = A2D2BaseSegmentationLoader()

    classes = np.zeros(num_classes, dtype=np.float64)
    present_num = np.zeros(num_classes, dtype=np.float64)
    # Frame-by-frame stat: frame_level_stat[class_idx][frm_idx]
    frame_level_stat = np.zeros((num_classes, len(labels)))
    frame_names = []
    for idx, lbl_path in enumerate(labels):
        if not (idx % 100):
            print("idx: ", idx)
        lbl = cv2.imread(lbl_path, 0)
        if lbl is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("could not read label image: {}".format(lbl_path))
        lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))
        frame_names.append(lbl_path)
        for nc in range(num_classes):
            num_pixel = (lbl == nc).sum()
            frame_level_stat[nc][idx] = num_pixel
            if num_pixel:
                classes[nc] += num_pixel
                present_num[nc] += 1

    classes = classes.astype("f")
    np.save("train_frame_level_stat", frame_level_stat)
    np.save("train_frame_names", np.array(frame_names))
    study_frame_level_stat(frame_level_stat=frame_level_stat, frame_names=frame_names)

    # If any class never occurred, use a count of 1 to avoid division by zero.
    classes[classes == 0] = 1
    present_num = present_num.astype("f")

    if method == "median":
        class_weight = calc_median_frequency(classes, present_num)
    else:
        class_weight = calc_log_frequency(classes)

    print_stats(classes=classes, class_weight=class_weight)

    print("Done!")


def study_stats():
    """Reload previously saved frame-level statistics and re-run the per-frame study."""
    stat_dir = "/data/ssd/datasets/a2d2_v2/info/frame_level_stat/"
    stats = np.load(stat_dir + "train_frame_level_stat.npy")
    names = np.load(stat_dir + "train_frame_names.npy")
    study_frame_level_stat(frame_level_stat=stats, frame_names=names)

if __name__ == '__main__':
    # Script entry point.
    # calc_weights() scans the label images, saves the frame-level statistics
    # to the current directory and prints class weights; it is disabled here.
    # study_stats() reloads previously saved statistics from fixed paths and
    # re-plots them.
    #calc_weights()
    study_stats()